#pragma once

#include <cassert>
#include <cmath>
#include <iostream>
#include <utility>

namespace simulation {
	/// One node of an AVLTree: a key/value pair, parent/child links and a balance factor.
	template<class K, class V>
	struct AVLTreeNode {
		std::pair<K, V> _kv; // key and value are stored together as a pair

		// Three-way links: both children plus a back-pointer to the parent.
		AVLTreeNode<K, V>* _left = nullptr;
		AVLTreeNode<K, V>* _right = nullptr;
		AVLTreeNode<K, V>* _parent = nullptr;

		// Balance factor: height(right subtree) - height(left subtree).
		// The subtree rooted here is AVL-balanced iff -1 <= _bf <= 1.
		int _bf = 0;

		AVLTreeNode(const std::pair<K, V>& kv)
			: _kv(kv)
		{}
	};

	/// Self-balancing binary search tree (AVL) mapping unique keys K to values V.
	/// Keys are ordered with `operator<` only. The tree owns its nodes: copying makes a
	/// deep copy, moving transfers ownership, and the destructor releases every node.
	template<class K, class V>
	class AVLTree {
	private:
		using Node = AVLTreeNode<K, V>;
		Node* _root;

	public:
		AVLTree()
			: _root(nullptr)
		{}

		/// Deep-copies `other`, preserving its shape, balance factors and parent links.
		AVLTree(const AVLTree& other)
			: _root(_copy(other._root, nullptr))
		{}

		/// Takes over `other`'s nodes; `other` is left empty.
		AVLTree(AVLTree&& other) noexcept
			: _root(std::exchange(other._root, nullptr))
		{}

		/// Copy- and move-assignment (copy-and-swap): the parameter is built by the copy or
		/// move constructor, swapped in, and the previous nodes are released with it.
		AVLTree& operator=(AVLTree other) noexcept
		{
			swap(other);
			return *this;
		}

		~AVLTree()
		{
			_destroy(_root);
		}

		/// Exchanges the contents of two trees in O(1).
		void swap(AVLTree& other) noexcept
		{
			std::swap(_root, other._root);
		}

		/// Inserts `kv` unless a node with an equal key already exists.
		/// @return true if a node was inserted, false if the key was already present.
		bool insert(const std::pair<K, V>& kv)
		{
			if (_root == nullptr)
			{
				_root = new Node(kv);
				return true;
			}

			// Descend to the empty slot where kv belongs.
			Node* parent = nullptr;
			Node* cur = _root;
			while (cur)
			{
				parent = cur;
				if (cur->_kv.first < kv.first)
				{
					cur = cur->_right;
				}
				else if (kv.first < cur->_kv.first)
				{
					cur = cur->_left;
				}
				else
				{
					return false; // duplicate keys are not allowed
				}
			}

			// Link the new leaf under `parent` (don't forget the _parent back-pointer).
			cur = new Node(kv);
			if (kv.first < parent->_kv.first)
			{
				parent->_left = cur;
			}
			else
			{
				parent->_right = cur;
			}
			cur->_parent = parent;

			// Walk back up, updating the balance factors of the ancestors.
			while (parent)
			{
				if (cur == parent->_right)
				{
					++parent->_bf;
				}
				else
				{
					--parent->_bf;
				}

				if (parent->_bf == 0)
				{
					// parent was -1/1 and the insertion filled its shorter side:
					// its height is unchanged, so no ancestor needs updating.
					break;
				}

				if (parent->_bf == 1 || parent->_bf == -1)
				{
					// parent was 0 and has grown taller by one: keep propagating upward.
					cur = parent;
					parent = parent->_parent;
					continue;
				}

				// |bf| == 2: the subtree at parent is unbalanced and must be rotated.
				assert(parent->_bf == 2 || parent->_bf == -2);
				if (parent->_bf == 2)
				{
					if (cur->_bf == 1)
					{
						rotateL(parent);
					}
					else
					{
						rotateRL(parent);
					}
				}
				else
				{
					if (cur->_bf == -1)
					{
						rotateR(parent);
					}
					else
					{
						rotateLR(parent);
					}
				}

				// A rotation restores the subtree's pre-insertion height, so we are done.
				break;
			}

			return true;
		}

		/// @return true if a node with a key equal to `key` is in the tree.
		bool contains(const K& key) const
		{
			const Node* cur = _root;
			while (cur)
			{
				if (key < cur->_kv.first)
				{
					cur = cur->_left;
				}
				else if (cur->_kv.first < key)
				{
					cur = cur->_right;
				}
				else
				{
					return true;
				}
			}
			return false;
		}

		/// Writes every node as "key:value" on its own line, in ascending key order.
		void inOrder(std::ostream& os = std::cout) const
		{
			_inOrder(_root, os);
		}

		/// Debug check: verifies every stored balance factor equals the real height
		/// difference (right - left) and that every node satisfies |bf| <= 1.
		/// Prints a diagnostic for the first node whose stored factor is wrong.
		bool isBalance() const
		{
			int height = 0;
			return _isBalance(_root, height);
		}

	private:
		/// Left rotation around `parent` (fixes a right-right imbalance).
		void rotateL(Node* parent)
		{
			assert(parent);

			Node* grandparent = parent->_parent;
			Node* subR = parent->_right;   // right child of parent
			Node* subRL = subR->_left;     // left child of subR

			parent->_right = subRL;
			if (subRL)
			{
				subRL->_parent = parent;
			}

			subR->_left = parent;
			parent->_parent = subR;

			if (_root == parent)
			{
				_root = subR;
			}
			else if (parent == grandparent->_left)
			{
				grandparent->_left = subR;
			}
			else
			{
				grandparent->_right = subR;
			}
			subR->_parent = grandparent;

			subR->_bf = parent->_bf = 0;
		}

		/// Right rotation around `parent` (fixes a left-left imbalance).
		void rotateR(Node* parent)
		{
			assert(parent);

			Node* grandparent = parent->_parent;
			Node* subL = parent->_left;    // left child of parent
			Node* subLR = subL->_right;    // right child of subL

			parent->_left = subLR;
			if (subLR)
			{
				subLR->_parent = parent;
			}

			subL->_right = parent;
			parent->_parent = subL;

			if (_root == parent)
			{
				_root = subL;
			}
			else if (parent == grandparent->_left)
			{
				grandparent->_left = subL;
			}
			else
			{
				grandparent->_right = subL;
			}
			subL->_parent = grandparent;

			subL->_bf = parent->_bf = 0;
		}

		/// Left-right double rotation around `parent` (fixes a left-right imbalance).
		void rotateLR(Node* parent)
		{
			assert(parent);

			Node* subL = parent->_left;
			Node* subLR = subL->_right;
			const int bf = subLR->_bf; // tells which side of subLR received the insertion

			rotateL(subL);
			rotateR(parent);

			// Both single rotations zeroed parent, subL and subLR; at most one of
			// subL / parent ends up with a non-zero factor.
			if (bf == 1)
			{
				subL->_bf = -1;
			}
			else if (bf == -1)
			{
				parent->_bf = 1;
			}
			else
			{
				assert(bf == 0); // subLR was the new node itself: all three stay 0
			}
		}

		/// Right-left double rotation around `parent` (fixes a right-left imbalance).
		void rotateRL(Node* parent)
		{
			assert(parent);

			Node* subR = parent->_right;
			Node* subRL = subR->_left;
			const int bf = subRL->_bf; // tells which side of subRL received the insertion

			rotateR(subR);
			rotateL(parent);

			if (bf == 1)
			{
				parent->_bf = -1;
			}
			else if (bf == -1)
			{
				subR->_bf = 1;
			}
			else
			{
				assert(bf == 0); // subRL was the new node itself: all three stay 0
			}
		}

		void _inOrder(const Node* root, std::ostream& os) const
		{
			if (root == nullptr) return;

			_inOrder(root->_left, os);
			os << root->_kv.first << ":" << root->_kv.second << '\n';
			_inOrder(root->_right, os);
		}

		/// Post-order check of the subtree at `root`; on return `height` holds its height.
		/// Runs in O(n) because each subtree height is computed exactly once.
		bool _isBalance(const Node* root, int& height) const
		{
			if (root == nullptr)
			{
				height = 0;
				return true;
			}

			int leftHeight = 0;
			int rightHeight = 0;
			if (!_isBalance(root->_left, leftHeight) || !_isBalance(root->_right, rightHeight))
			{
				return false;
			}
			height = (leftHeight > rightHeight ? leftHeight : rightHeight) + 1;

			const int bf = rightHeight - leftHeight;
			if (bf != root->_bf)
			{
				std::cout << root->_kv.first << ":" << "balance factor is wrong!" << std::endl;
				return false;
			}

			return std::abs(bf) <= 1;
		}

		/// Recursively clones the subtree at `src`, attaching the clone under `parent`.
		/// If copying any node throws, everything cloned so far is freed before rethrowing.
		static Node* _copy(const Node* src, Node* parent)
		{
			if (src == nullptr) return nullptr;

			Node* node = new Node(src->_kv);
			node->_bf = src->_bf;
			node->_parent = parent;
			try
			{
				node->_left = _copy(src->_left, node);
				node->_right = _copy(src->_right, node);
			}
			catch (...)
			{
				_destroy(node);
				throw;
			}
			return node;
		}

		/// Deletes every node of the subtree at `root` (post-order).
		static void _destroy(Node* root)
		{
			if (root == nullptr) return;

			_destroy(root->_left);
			_destroy(root->_right);
			delete root;
		}
	};
}